
Conversation

@leafs1 leafs1 (Contributor) commented Jul 1, 2025

Summary

This PR improves the exp_fp16 and gelu_fp16 tests by using a dynamic tolerance strategy similar to the XNNPACK tolerance calculation for validating float16 exponential kernels. Instead of relying on fixed absolute and relative tolerances, the test now calculates acceptable error bounds based on the output magnitude and float16 precision constraints. This change ensures correctness while accommodating the inherent limitations of float16 arithmetic.

Problem

While testing the float16 exponential kernel from XNNPACK against PyTorch's eager-mode implementation, sporadic failures occurred. The failures were caused by small mismatches between the output values, typically in the range of ~0.01 to ~0.015. These discrepancies occurred even though both outputs were reasonably close when viewed through the lens of float16 precision. The original test used fixed tolerance values (atol=1e-3, rtol=1e-3), which were too strict for float16 results, particularly for inputs that produced large exponentials.
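
For context on the scale of the mismatch: assuming a torch.allclose-style criterion (|out - ref| <= atol + rtol * |ref|), atol = rtol = 1e-3 permits an error of only about 0.0107 at |ref| ≈ 9.7, which is smaller than the ~0.0156 gaps observed in the investigation below.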

Investigation

To understand the failures, I traced specific cases where discrepancies occurred. For example, for the input 2.2715, PyTorch computes exp(2.2715) in float32 and rounds the result to float16, yielding 9.6953. In contrast, XNNPACK uses float16-only arithmetic throughout its kernel, computing a slightly lower value of 9.6797. The difference between the two outputs is exactly 0.0156, which corresponds to two ULPs (units in the last place) at that magnitude in float16. This led me to examine the structure of float16 and its numerical limits in detail.

Further analysis revealed that IEEE 754 half-precision floating point (float16) has limited resolution (only 10 bits for the significand), meaning the spacing between representable values increases with magnitude. Near 1.0, one ULP is about 0.00098, but near 9.7 it grows to 0.0078, so the 0.0156 gap above amounts to two ULPs. Given this, it became clear that small absolute differences in the output were not only expected but within the accuracy float16 arithmetic can realistically deliver.
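
To make the spacing concrete, here is a small numpy sketch (illustrative only, not part of this PR's test code) that reproduces the traced example and the ULP sizes quoted above:

```python
import numpy as np

x = np.float16(2.2715)

# Reference path: exp in float32, then round the result back to float16.
ref = np.float16(np.exp(np.float32(x)))      # 9.6953 (PyTorch's rounded float32 result)

# Output observed from the float16-only kernel in the failing case.
kernel_out = np.float16(9.6797)

# ULP = spacing between adjacent representable float16 values at a given magnitude.
print(np.spacing(np.float16(1.0)))           # ~0.000977 (2**-10)
print(np.spacing(np.float16(9.7)))           # ~0.0078   (2**-7)

# The observed discrepancy is two ULPs at this magnitude.
print(float(ref) - float(kernel_out))        # 0.015625
```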

To confirm the root cause, I reviewed the XNNPACK source code and documentation. Their float16 exponential kernel uses a 2^z * 2^r decomposition and evaluates a degree-3 polynomial entirely in float16 arithmetic, which accumulates additional rounding error at every step. More importantly, I found that XNNPACK’s own test infrastructure accepts outputs within a mixed tolerance of 2 × ε absolute and 6 × ε relative error, where ε ≈ 9.77e-4 is the machine epsilon for float16. This tolerance model is defined by their TolMixed function and effectively allows roughly 6 ULPs of error, depending on the output value.
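
For intuition only, here is a minimal float32 sketch of the same range-reduction idea, decomposing exp(x) into 2^z · e^r and approximating the reduced part with a degree-3 polynomial. It uses plain Taylor coefficients and float32 throughout, so it is not XNNPACK's actual kernel, coefficients, or float16 arithmetic:

```python
import numpy as np

LN2 = np.float32(np.log(2.0))

def exp_range_reduced(x):
    """Illustrative exp(x) = 2**z * e**r with a degree-3 polynomial for e**r."""
    x = np.asarray(x, dtype=np.float32)
    z = np.round(x / LN2)                  # integer power of two
    r = x - z * LN2                        # reduced argument, |r| <= ln(2)/2
    # Degree-3 Taylor polynomial; XNNPACK uses tuned minimax coefficients instead.
    p = 1.0 + r * (1.0 + r * (0.5 + r * (1.0 / 6.0)))
    return np.ldexp(p, z.astype(np.int32)) # scale by 2**z

x = np.float32(2.2715)
print(exp_range_reduced(x), np.exp(x))     # close, but not bit-identical
```

Because every intermediate result in the real kernel is rounded to float16, each step can add up to half a ULP of error, so a discrepancy of a few ULPs against a float32 reference is expected.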

Solution

This PR updates the exp_fp16 and gelu_fp16 tests to use the same tolerance policy as XNNPACK. For float16 inputs, the test now computes the reference output in float32 precision, then applies the following tolerance calculation (sketched in code below):

- Absolute tolerance: 2 × ε ≈ 0.00195
- Relative tolerance: 6 × ε ≈ 0.00586
- Final tolerance per output: max(atol, rtol × |y_ref|)
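
As a sketch of what this check looks like (a minimal illustration, not the exact test code in this PR; `kernel_out`, `fake_kernel_out`, and `assert_close_fp16` are placeholder names for the delegate's float16 output and the helper):

```python
import torch

def assert_close_fp16(kernel_out: torch.Tensor, x: torch.Tensor) -> None:
    """Mixed-tolerance check mirroring XNNPACK's TolMixed policy for exp."""
    eps = torch.finfo(torch.float16).eps            # ~9.77e-4
    atol = 2 * eps                                  # ~0.00195
    rtol = 6 * eps                                  # ~0.00586
    ref = torch.exp(x.to(torch.float32))            # float32 reference
    tol = torch.clamp(rtol * ref.abs(), min=atol)   # max(atol, rtol * |y_ref|)
    err = (kernel_out.to(torch.float32) - ref).abs()
    assert bool((err <= tol).all()), f"max err {err.max().item():.5f} exceeds tolerance"

x = torch.linspace(-4.0, 4.0, 257).to(torch.float16)
fake_kernel_out = torch.exp(x.to(torch.float32)).to(torch.float16)  # stand-in output
assert_close_fp16(fake_kernel_out, x)
```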

Test plan

I added the new atol and rtol values to the test suite and ran the exp_fp16 and gelu_fp16 tests with a variety of random inputs to confirm they now pass.

@leafs1 leafs1 requested review from digantdesai and mcr229 as code owners July 1, 2025 22:09
@pytorch-bot pytorch-bot bot commented Jul 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/12150


❌ 2 Cancelled Jobs

As of commit 0882ae1 with merge base 75d4b2e:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 1, 2025
@leafs1 leafs1 changed the title Adjust tolerance for fp16 exp op to handle reasonable calculation dis… Adjust tolerance for fp16 exp op test to handle reasonable calculation discrepancies Jul 1, 2025
@leafs1 leafs1 force-pushed the expFp16Fix branch 4 times, most recently from 0972df0 to 78e8ab8 on July 1, 2025 at 23:33
@leafs1 leafs1 (Contributor, Author) commented Jul 1, 2025

@pytorchbot label "release notes: none"

@pytorch-bot pytorch-bot bot added the release notes: none Do not include this in the release notes label Jul 1, 2025
@GregoryComer GregoryComer (Member) left a comment


Nice analysis! Reasoning and conclusion look good to me.

@leafs1 leafs1 force-pushed the expFp16Fix branch 2 times, most recently from 65976f2 to 8a3afda on July 2, 2025 at 00:29
@leafs1 leafs1 changed the title Adjust tolerance for fp16 exp op test to handle reasonable calculation discrepancies Adjust tolerance for fp16 exp & gelu ops test to handle reasonable calculation discrepancies Jul 2, 2025
@digantdesai digantdesai (Contributor) commented

Good work!

> I found that XNNPACK’s own test infrastructure accepts outputs within a mixed tolerance of 2 × ε absolute and 6 × ε relative error, where ε ≈ 9.77e-4 is the machine epsilon for float16

2 and 6 seem arbitrary and could be input-range related; did you dig into why these multipliers work in the XNNPACK code?

@leafs1 leafs1 merged commit 29858b4 into pytorch:main Jul 3, 2025
165 of 169 checks passed
Tanish2101 pushed a commit to Tanish2101/executorch that referenced this pull request Jul 9, 2025
…lculation discrepancies (pytorch#12150)

